package handler

import (
	"context"
	"github.com/DiracLee/dires-go/app/connection"
	"github.com/DiracLee/dires-go/app/payload"
	"github.com/DiracLee/dires-go/app/server"
	"github.com/DiracLee/dires-go/logger"
	"github.com/DiracLee/dires-go/tcp/parser"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
)

// DiresHandler serves client connections: it parses requests from each
// connection, executes them on server, and writes the replies back.
type DiresHandler struct {
	// activeConn tracks the live client connections. Keys are the
	// connections themselves; the stored value (1) is only a placeholder.
	activeConn sync.Map
	// server executes parsed commands, is notified when a client
	// disconnects, and is closed when the handler is closed.
	server     server.Server
	// closed becomes non-zero once Close has run; it is read and written
	// through sync/atomic.
	closed     uint32
}

// unknownErrPayloadBytes is the reply written to a client when the server's
// Execute returns no result for a command.
var unknownErrPayloadBytes = []byte("-ERR unknown\r\n")

// Handle serves a single client connection until the connection is closed or
// the parsed request stream ends.
//
// Each parsed request must be a multi-bulk payload. Its arguments are passed to
// the server's Execute, and the reply (or unknownErrPayloadBytes when Execute
// returns nil) is written back to the client. Parse errors other than EOF or a
// closed connection are reported to the client as error payloads, and reading
// continues. EOF, unexpected EOF, a closed network connection, any failed write,
// or the end of the stream terminate the session and release the connection.
//
// If the handler has already been closed, the connection is closed and refused
// immediately. ctx is currently not consulted.
func (handler *DiresHandler) Handle(ctx context.Context, c net.Conn) {
	if atomic.LoadUint32(&handler.closed) != 0 {
		// Closed handler refuses new connection: close it and do not
		// register or serve it.
		_ = c.Close()
		return
	}
	conn := connection.New(c)
	handler.activeConn.Store(conn, 1)

	ch := parser.ParseStream(c)
	for result := range ch {
		if result.Err != nil {
			if result.Err == io.EOF ||
				result.Err == io.ErrUnexpectedEOF ||
				strings.Contains(result.Err.Error(), "use of closed network connection") {
				// connection closed
				handler.closeClient(conn)
				logger.Infof("connection (%v) closed", conn.RemoteAddr().String())
				return
			}
			// Protocol error: report it to the client and keep reading.
			errPayload := payload.NewErrPayload(result.Err.Error())
			if _, err := conn.Write(errPayload.Bytes()); err != nil {
				handler.closeClient(conn)
				logger.Info("connection closed: " + conn.RemoteAddr().String())
				return
			}
			continue
		}
		if result.Data == nil {
			logger.Error("empty result")
			continue
		}
		r, ok := result.Data.(*payload.MultiBulkPayload)
		if !ok {
			logger.Error("require multi bulk payload")
			continue
		}
		// Named reply (not result) so the loop variable is not shadowed.
		reply := handler.server.Execute(conn, r.Args)
		payloadBytes := unknownErrPayloadBytes
		if reply != nil {
			payloadBytes = reply.Bytes()
		}
		if _, err := conn.Write(payloadBytes); err != nil {
			// Replies can no longer be delivered. Release the client, as the
			// error-reply path above does, instead of serving a broken conn.
			logger.Error("conn failed to write bytes")
			handler.closeClient(conn)
			return
		}
	}
	// The request stream ended without a terminal error reaching the loop.
	// Release the connection so it does not linger in activeConn and the
	// server still learns about the disconnect.
	handler.closeClient(conn)
}

// Close shuts the handler down. It marks the handler closed so that Handle
// refuses new connections, closes every tracked client connection, and then
// closes the server.
//
// Only the first call does this work; later calls just log and return. Close
// always returns nil.
func (handler *DiresHandler) Close() error {
	logger.Info("shut down handler...")
	if !atomic.CompareAndSwapUint32(&handler.closed, 0, 1) {
		// avoid closing a closed handler
		return nil
	}
	handler.activeConn.Range(func(key, _ interface{}) bool {
		conn, ok := key.(connection.Connection)
		if !ok {
			// Skip this entry. Calling Close on the zero-value interface
			// would panic and abort the rest of the shutdown.
			logger.Error("[Handler::Close] unexpected connection detected")
			return true
		}
		// Best effort: a connection that fails to close should not stop
		// the remaining ones from being closed.
		_ = conn.Close()
		return true
	})
	handler.server.Close()
	return nil
}

// closeClient ends one client session. It closes the connection (ignoring any
// close error), tells the server the client has disconnected, and stops
// tracking the connection in activeConn.
func (handler *DiresHandler) closeClient(client connection.Connection) {
	_ = client.Close() // best effort: the session is being discarded either way
	handler.server.Disconnect(client)
	handler.activeConn.Delete(client)
}
